import numpy as np
import math
import pandas as pd
import xarray as xr
import os
from scipy import interpolate
from scipy.spatial.distance import cdist
from netCDF4 import Dataset
import sys
import datetime
import gc
import h5py
from datetime import datetime,timedelta
sys.path.append('./')
import multiprocess as mp
import tqdm
#

###### 
#####
from interp_radar import interp_radar_shift

def ptqvapor2rh(pin,tin,qvin) :
    """Compute relative humidity (%) from pressure, temperature and humidity.

    Parameters
    ----------
    pin : array_like or scalar
        Pressure. If any value exceeds 10000 the whole input is taken to be
        in Pa and divided by 100; otherwise it is used as hPa.
    tin : array_like or scalar
        Temperature. If any value is below 60 the whole input is taken to be
        in degC and 273.15 is added; otherwise it is used as K.
    qvin : array_like or scalar
        Water-vapour amount. The vapour-pressure expression below is the
        specific-humidity form e = q*p / (0.622 + 0.378*q), so this is
        presumably specific humidity in kg/kg -- confirm against the input file.

    Returns
    -------
    ndarray or scalar
        Relative humidity in percent with respect to liquid water. Negative
        values are clipped to 0. Values above 100 are kept. NaN propagates.
    """
    t0=273.15     # degC -> K offset
    eps=0.622     # ratio used in the vapour-pressure formula below

    # Unit detection is all-or-nothing over the whole input.
    t = tin + t0 if np.any(tin < 60) else tin
    p = pin / 100. if np.any(pin > 10000) else pin

    # Saturation vapour pressure over water (hPa), Magnus/Bolton form in K.
    E=6.112*np.exp(17.67*(t-273.16)/(t-29.65))
    # Actual vapour pressure (hPa).
    e=qvin*p/(qvin*(1-eps)+eps)
    rh=e/E*100
    # np.maximum instead of in-place boolean assignment, so scalar inputs work
    # too. NaN stays NaN, and masked arrays keep their mask.
    return np.maximum(rh,0.)

# ---------------------------------------------------------------------------
# Run configuration
# ---------------------------------------------------------------------------
max_1D_vars=1
start_ens=0            # index of the first ensemble member (directories gepNN)
max_ens=1              # number of members to read ...
max_ens+=start_ens     # ... turned into the exclusive end index for np.arange

############################################
# Half-width (grid points) of the horizontal shift window applied to the
# background columns; shifts run from -shift_max to +shift_max.
shift_max=3
# Reflectivity threshold (dBZ, compared after the same Rvarname[0] scaling).
# Background columns are kept when their column maximum exceeds it. Observed
# columns are retrieved when more than 10 bins exceed it.
dbz_threshold=18
#out_maxelev=maxelev*(2*shift_max+1)*(2*shift_max+1)

# Number of closest background profiles averaged per retrieved column.
kargindex=20

############################################
# Plot settings (not referenced in the code shown here).
dpis=200
font_size=24
fig_size=(28,16)

MISSING=-9999.         # fill value for missing data

threshold= -80.        # heights / interpolated background values <= this are invalid
threshold_PMR= 18.     # observed PMR reflectivity (dBZ) below this is discarded

linewd=3

D1_3DVAR=True
plot_flag=True

###########
# Observation variables and their scales; data are divided by these.
#varname=np.array(['dBZ','VEL','KDP','ZDR','RHOHV'])
varname=np.array(['dBZ','KDP','ZDR','RHOHV'])
Rvarname=np.array([1.5,0.2,0.2,0.01])
# Background (meso model) variables stacked into each candidate profile.
mesovarname=np.array(['Qv','t','rh','Qr'])
#mesovarname=np.array(['rh'])
#mesovarname=np.array(['pressure','QVAPOR','temp','rh','QRAIN','QNRAIN','QGRAUP'])
# Derived from the list so the two cannot drift apart: the variable loop
# indexes mesovarname with range(mesovarnum).
mesovarnum=len(mesovarname)

obsvarname = {'dBZ':'grdtiltref', 'VEL': 'grdtiltvel', 'KDP': 'grdtiltkdp','ZDR': 'grdtiltzdr', 'RHOHV':'grdtiltrhv' }
B_varname = {'dBZ':'ZH', 'VEL': 'grdtilthighvel', 'KDP': 'kdp','ZDR': 'zdr', 'RHOHV':'rhohv' }
###########

#vdate = datetime(2023,8,20,15,00,00)
vdate = datetime(2023,7,28,0,0,0)

##########
#obs
##########
# Read the FY-3G PMR Ku-band swath. Every dataset is read in full with [:],
# so the HDF5 file is closed as soon as the reads finish. Previously the
# handle was never closed.
filename='./FY3G_PMR.HDF'
with h5py.File(filename,'r') as f:
    height=f['/Geolocation/Ku/height'][:]
    Latitude=f['/Geolocation/Ku/Latitude'][:]
    Longitude=f['/Geolocation/Ku/Longitude'][:]
    zFactorMeasured=f['/PRE/Ku/zFactorMeasured'][:]
    print(zFactorMeasured.shape)
    flagPrecip=f['/PRE/Ku/flagPrecip'][:]
    binClutterFreeBottom=f['/PRE/Ku/binClutterFreeBottom'][:]
    dayCount=f['/Geolocation/Ku/dayCount'][:]
    msCount=f['/Geolocation/Ku/msCount'][:]

# Scan times are offsets from 2000-01-01 12:00.
datebase= datetime(2000, 1, 1, 12, 0 , 0)
deltaday = timedelta(hours=24)
deltahour = timedelta(hours=1)
deltaseconds = timedelta(seconds=1)
millidelta = timedelta(milliseconds=1)

# Average over the third geolocation axis -> one lat/lon per (nscan, nray).
Longitude_2d=Longitude.mean(axis=2)
Latitude_2d=Latitude.mean(axis=2)

# NOTE(review): msCount is scaled as 1e-4 s per count here, while the
# commented alternative treats it as milliseconds -- confirm the unit.
datestart=datebase+deltaday*dayCount[0]+deltaseconds*msCount[0]/10000.
dateend=datebase+deltaday*dayCount[-1]+deltaseconds*msCount[-1]/10000.
#datestart=datebase+deltaday*dayCount[0]+millidelta*msCount[0]
#dateend=datebase+deltaday*dayCount[-1]+millidelta*msCount[-1]

dims=zFactorMeasured.shape
nscan=dims[0]
nray=dims[1]
nbin=dims[2]
print(dims)

ds=xr.Dataset(
    data_vars=dict(
        zFactorMeasured=(["nscan","nray","nbin"],zFactorMeasured),
        height=(["nscan","nray","nbin"],height),
        binClutterFreeBottom=(["nscan","nray"],binClutterFreeBottom),
        flagPrecip=(["nscan","nray"],flagPrecip),
    ),
    coords=dict(
        Longitude=(["nscan","nray"],Longitude_2d),
        Latitude=(["nscan","nray"],Latitude_2d),
        nbin=(["nbin"],np.arange(nbin)),
        nscan=(["nscan"],np.arange(nscan)),
        nray=(["nray"],np.arange(nray)),
    ),
)
# Blank every bin whose index is greater than binClutterFreeBottom.
cond=(ds['zFactorMeasured'].nbin>ds['binClutterFreeBottom'])
ds['zFactorMeasured']=xr.where(cond,np.nan,ds['zFactorMeasured'])

# Keep only rays with flagPrecip == 1.
cond=(ds['flagPrecip']==1)
ds['zFactorMeasured']=xr.where(cond,ds['zFactorMeasured'],np.nan)

# Scale by Rvarname[0]; NaN marks removed bins.
Rvar=Rvarname[0]
obs_all=np.array(ds['zFactorMeasured']/Rvar)

radar_Z_out_all= np.array(ds['height'])

#obs "nscan","nray","nbin"  (3885, 59, 500)

#############
# B_meso
############
debug=False
#debug=True
debug1=True

vertical_flag=1  # 0  sigma 1 pressure 2 Z


# chongqing case: background sub-domain as grid indices into the full meso
# grid (i runs west-east, j south-north, matching we/sn below).
istart  = 230
iend    = 450
jstart  = 100
jend    = 250

#(slice(130,270), slice(230, 450))


# Profile-base region: offsets relative to istart/jstart of the columns whose
# background profiles form the candidate set (also used for the swath
# background HX region).
ishift_start_b= 3
ishift_end_b= 217
jshift_start_b= 3
jshift_end_b= 147

# Output / retrieval region, relative to istart/jstart.
ishift_start= 3
ishift_end= 217
jshift_start= 3
jshift_end=  147

we=iend-istart
sn=jend-jstart
we_out=ishift_end-ishift_start
sn_out=jshift_end-jshift_start
B_dir='./'

for id_ens in np.arange(start_ens,max_ens):
    ensmem='gep'+str(id_ens).zfill(2)

    B_meso_fullname=os.path.join(B_dir,ensmem,"postvar.nc")
    print(B_meso_fullname)
    # Context manager: each member file is closed once its fields are in
    # memory. netCDF4 slicing returns in-memory arrays, so nothing below
    # needs the open handle. Previously the files were never closed.
    with Dataset(B_meso_fullname,"r",format="NETCDF4") as B_mesoout:
        for id_mesovar in np.arange(0,mesovarnum):
            varname_meso=mesovarname[id_mesovar]
            if varname_meso=='rh' :
                # rh is computed, not read: from t, Qv and the pressure levels
                # broadcast to the 3-D grid.
                meso_t = B_mesoout.variables['t'][0,:,:,:]
                meso_Qv = B_mesoout.variables['Qv'][0,:,:,:]
                meso_p_0 = B_mesoout.variables['levels'][:]
                meso_p_1 = np.repeat(meso_p_0[:,None],meso_t.shape[1],axis=1)
                meso_p = np.repeat(meso_p_1[:,:,None],meso_t.shape[2],axis=2)
                meso_value_tmp1 = ptqvapor2rh(meso_p,meso_t,meso_Qv)
                del meso_t
                del meso_Qv
                del meso_p_0
                del meso_p_1
                del meso_p
                gc.collect()
            else :
                meso_value_tmp1 = B_mesoout.variables[varname_meso][0,:,:,:]
            if debug :
                print("meso_value_tmp1.shape")
                print(meso_value_tmp1.shape)

            # Stack variables along a new trailing axis -> (nz,ny,nx,nvar).
            if id_mesovar==0 :
                meso_value=np.copy(meso_value_tmp1[...,None])
            else :
                meso_value=np.concatenate((meso_value,meso_value_tmp1[...,None]),axis=-1)
            if debug :
                print("meso_value.shape")
                print(meso_value.shape)

        # Grid coordinates, copied into memory before the file closes. They are
        # re-read for every member, so (as before) the last member's grid is used.
        XLAT=B_mesoout.variables['latitudes'][:]
        XLONG=B_mesoout.variables['longitudes'][:]

    # Stack members along axis 1 -> (nz,nens,ny,nx,nvar).
    if id_ens==start_ens :
        meso_all=np.copy(meso_value[:,None,...])
    else :
        meso_all=np.concatenate((meso_all,meso_value[:,None,...]),axis=1)
        if debug :
            print(meso_all.shape)


#base profile box: absolute indices into the full meso grid
istart_shift=ishift_start_b + istart
iend_shift=ishift_end_b + istart
jstart_shift=jshift_start_b + jstart
jend_shift=jshift_end_b + jstart


# Candidate background profiles flattened to (nz, nens*ny_box*nx_box, nvar).
profile_base=meso_all[...,jstart_shift:jend_shift,istart_shift:iend_shift,:].reshape([meso_all.shape[0],-1,meso_all.shape[-1]])

#        float latitudes(lat) ;
#                latitudes:units = "degree_north" ;
#        float longitudes(lon) ;
#                longitudes:units = "degree_east" ;
lat=XLAT[jstart:jend]
lon=XLONG[istart:iend]

# 2-D lat/lon of the sub-domain, shape (sn, we).
lon_meso=np.repeat(lon[None,:],len(lat),axis=0)
lat_meso=np.repeat(lat[:,None],len(lon),axis=1)

# meso_value   (nz,ny,nx,nvar)
# meso_all     (nz,nens,ny,nx,nvar)

###################
# from ZJU_AERO
###################
# Read each member's simulated spaceborne-radar swath (reflectivity ZH and
# bin heights) and stack the members.
for id_ens in np.arange(start_ens,max_ens):
    ensmem='gep'+str(id_ens).zfill(2)

    B_ZJU_fullname=os.path.join(B_dir,ensmem,"swath_spaceborn.nc")
    print(B_ZJU_fullname)

    # Context manager closes the file. Everything needed is converted to
    # numpy (which loads the data) inside the block.
    with xr.open_dataset(B_ZJU_fullname) as B_ds:
        # Heights not above `threshold` are replaced by MISSING.
        cond=(B_ds['heights'] > threshold)
        heights=xr.where(cond,B_ds['heights'],MISSING)    # m

        Rvar=Rvarname[0]
        # Clip negative ZH to 0, then scale by Rvar. The previous code computed
        # the clipped field and immediately overwrote it with the raw ZH/Rvar,
        # so the clipping never took effect. NaN is not < 0, so NaNs survive
        # xr.where and become MISSING below.
        cond=(B_ds['ZH'] < 0.)
        B_value=xr.where(cond,0.,B_ds['ZH'])/Rvar

        # Move the last axis to the front (the original notes this yields
        # (nz,ny,nx)), then add a variable axis -> (nz,1,ny,nx).
        model_Value_in_tmp1 = np.moveaxis(np.array(B_value.fillna(MISSING)), -1,0)[:,None,:,:]
        model_Z_in_tmp = np.moveaxis(np.array(heights),-1,0)

    # Stack members: model_Z_in (nz,nens,ny,nx), model_Value_in (nz,nvar,nens,ny,nx).
    if id_ens==start_ens :
        model_Z_in = np.copy(model_Z_in_tmp[:,None,...])
        model_Value_in=np.copy(model_Value_in_tmp1[:,:,None,...])
    else :
        model_Z_in=np.concatenate((model_Z_in,model_Z_in_tmp[:,None,...]),axis=1)
        model_Value_in=np.concatenate((model_Value_in,model_Value_in_tmp1[:,:,None,...]),axis=2)


print("model_Value_in")
print(model_Value_in.shape)

## model_Z_in         (nz,nens,ny,nx)        bin heights
## model_Value_in     (nz,nvar,nens,ny,nx)   ZH (room for KDP, ZDR, ... on nvar)

if debug :
    print("model_Value_in.shape")
    print(model_Value_in.shape)
    print("model_Z_in.shape")
    print(model_Z_in.shape)

# Window of the swath background (indices into model_Value_in / model_Z_in,
# which presumably cover only the istart:iend, jstart:jend sub-domain --
# TODO confirm) whose columns become candidate profiles. It uses the same
# *_b offsets as profile_base, so the column sets line up.
#jstart_shift=shift_max
#istart_shift=shift_max
#jend_shift=sn-shift_max-1
#iend_shift=we-shift_max-1

istart_shift=ishift_start_b
iend_shift=ishift_end_b
jstart_shift=jshift_start_b
jend_shift=jshift_end_b


# Build the shifted stack. For every offset (jj, ii) in
# [-shift_max, shift_max]^2, take the window displaced by that offset.
# ii offsets are concatenated along the shift_i axis (axis 3 of values,
# axis 2 of heights); jj offsets along the shift_j axis (axis 2 of values,
# axis 1 of heights). The order matters: it must match how the shifted
# columns are paired with the observation neighbourhood later.
for jj in np.arange(-shift_max,shift_max+1) :
    for ii in np.arange(-shift_max,shift_max+1) :
        jst=jstart_shift+jj
        jed=jend_shift+jj
        ist=istart_shift+ii
        ied=iend_shift+ii
        if (ii==-shift_max)  :
            model_Value_base_tmp=np.copy(model_Value_in[:,:,None,None,:,jst:jed,ist:ied])
            model_Z_base_tmp=np.copy(model_Z_in[:,None,None,:,jst:jed,ist:ied])
        else :
            if debug :
                print("model_Value_base_tmp.shape")
                print(model_Value_base_tmp.shape)
                print("model_Value_in.shape")
                print(model_Value_in[:,:,None,None,:,jst:jed,ist:ied].shape)
            model_Value_base_tmp=np.concatenate((model_Value_base_tmp,model_Value_in[:,:,None,None,:,jst:jed,ist:ied]),axis=3)
            model_Z_base_tmp=np.concatenate((model_Z_base_tmp,model_Z_in[:,None,None,:,jst:jed,ist:ied]),axis=2)
    if (jj==-shift_max) :
        model_Value_base1=np.copy(model_Value_base_tmp)
        model_Z_base1=np.copy(model_Z_base_tmp)
    else :
        model_Value_base1=np.concatenate((model_Value_base1,model_Value_base_tmp),axis=2)
        model_Z_base1=np.concatenate((model_Z_base1,model_Z_base_tmp),axis=1)

# Per-column maximum of variable 0 (ZH) over the unshifted window, flattened
# over (nens, ny, nx). It is the selection key for every *_base2 array below.
model_Value_formax=model_Value_in[...,jstart_shift:jend_shift,istart_shift:iend_shift].reshape([model_Value_in.shape[0],model_Value_in.shape[1],-1])
nanmax_Value=np.nanmax(model_Value_formax[:,0,:],axis=0)
# Flatten (nens, ny, nx) of the shifted stack into one column axis.
model_Value_base0=model_Value_base1.reshape([model_Value_base1.shape[0],model_Value_base1.shape[1],model_Value_base1.shape[2],model_Value_base1.shape[3],-1])
if debug :
    print("model_Value_base1.shape")
    print(model_Value_base1.shape)
    print(np.nanmax(model_Value_formax[:,0,...]))

####
#### model_Value_base1.shape
#### (nz, nvar, shift_j, shift_i, nens, jed-jst, ied-ist)

if debug :
    print("nanmax_Value.shape")
    print(nanmax_Value.shape)
    print(np.nanmax(nanmax_Value))
# Keep only columns whose maximum exceeds dbz_threshold (in the same
# Rvarname[0]-scaled units as the data).
model_Value_base2=model_Value_base0[:,:,:,:,nanmax_Value>dbz_threshold/Rvarname[0]]
if debug :
    print("model_Value_base2.shape")
    print(model_Value_base2.shape)
del model_Value_formax
del model_Value_base1
del model_Value_base0
#del model_Value_in
#del nanmax_Value
gc.collect()

# Same column selection for the heights: (nz, shift_j, shift_i, ncol).
# model_Z_in     (nz,nens,ny,nx)
model_Z_base0=model_Z_base1.reshape([model_Z_base1.shape[0],model_Z_base1.shape[1],model_Z_base1.shape[2],-1])
model_Z_base2=model_Z_base0[:,:,:,nanmax_Value>dbz_threshold/Rvarname[0]]
del model_Z_base1
del model_Z_base0
#del model_Z_in
#del nanmax_Value
gc.collect()


if debug :
    print("model_Z_base2.shape")
#    print(model_Z_base2)
    print(model_Z_base2.shape)
#   (50, 3, 3, 3427)

# Same selection on the candidate profiles. This assumes profile_base's column
# axis has the same length and (nens, ny, nx) order as nanmax_Value, so
# candidate indices agree across model_Value_base2, model_Z_base2 and
# profile_base0.
profile_base0=profile_base[:,nanmax_Value>dbz_threshold/Rvarname[0],:]
if debug :
    print("profile_base0.shape")
    print(profile_base0.shape)
    print("profile_base.shape")
    print(profile_base.shape)


##########
#
# OBS remap horizontal to WRF
#
#########
# Sub-swath of the PMR data used for this case: scans js_slice:je_slice,
# rays is_slice:ie_slice.
#js_slice=960
#je_slice=1150
js_slice=2750
je_slice=2900
is_slice=0
ie_slice=59
we_s= ie_slice-is_slice
sn_s= je_slice-js_slice
obs=obs_all[js_slice:je_slice,is_slice:ie_slice,:]
radar_Z_out_PMR=radar_Z_out_all[js_slice:je_slice,is_slice:ie_slice,:]

lat_gpm=Latitude_2d[js_slice:je_slice,is_slice:ie_slice]
lon_gpm=Longitude_2d[js_slice:je_slice,is_slice:ie_slice]

# Reference footprint (scan iscan, ray iray within the sub-swath) used to
# estimate the footprint spacing in degrees.
#iscan=(3051+3400-1)/2
iscan=2800-js_slice
iray=25-is_slice
# Squared spacing (deg^2): the smaller of the step to the next scan and the
# step to the next ray.
bin_res2=min((lat_gpm[iscan+1,iray]-lat_gpm[iscan,iray])**2+(lon_gpm[iscan+1,iray]-lon_gpm[iscan,iray])**2,
        (lat_gpm[iscan,iray+1]-lat_gpm[iscan,iray])**2+(lon_gpm[iscan,iray+1]-lon_gpm[iscan,iray])**2)
bin_res=np.sqrt(bin_res2)
# NOTE(review): this converts with 100 km per degree (1 deg latitude is about
# 111 km); bin_res_m is not used in the code shown.
bin_res_m=bin_res*100.*1000.

if debug :
    print("bin_res2")
    print(bin_res2)

# PMR profiles remapped onto the (sn, we) background grid, with bins on the
# first axis. MISSING where no footprint is close enough.
obs_mesh_meso=np.full(shape=([obs.shape[-1],sn,we]),fill_value=MISSING)
radar_Z_out=np.full(shape=([obs.shape[-1],sn,we]),fill_value=MISSING)
#D1_out_data=np.full(shape=([profile_base0.shape[0],profile_base0.shape[-1],sn,we]),fill_value=MISSING)
# out_data_shape is not referenced in the code shown.
out_data_shape=(profile_base0.shape[0],profile_base0.shape[-1],sn,we)
# Brute-force nearest-footprint search (in lat/lon degrees) for every
# background grid point.
for j in np.arange(0,sn) :
    for i in np.arange(0,we) :
        indi = np.unravel_index(np.argmin((lat_meso[j,i]-lat_gpm)**2 + (lon_meso[j,i]-lon_gpm)**2, axis=None), lat_gpm.shape) 
        dis=(lat_meso[j,i]-lat_gpm[indi])**2 + (lon_meso[j,i]-lon_gpm[indi])**2
        # Accept the nearest footprint only if it lies within two footprint
        # spacings (squared distance < (2*bin_res)^2). The bin axis is
        # reversed ([::-1]) for both reflectivity and height -- presumably so
        # the vertical axis runs in the same direction as the model; confirm.
        if (dis < 4.*bin_res2)  :
            obs_mesh_meso[:,j,i]=obs[indi[0],indi[1],::-1]
            radar_Z_out[:,j,i]=radar_Z_out_PMR[indi[0],indi[1],::-1]


#radar_Z_out[radar_Z_out>6000.]=np.nan
#obs_mesh_meso[radar_Z_out>6000.]=np.nan

if debug :
    print("max of obs_mesh_meso")
    print(np.nanmax(obs_mesh_meso))
    print(np.nanmax(obs))
    print("nanmax")
    print(np.nanmax(model_Value_base2))
    print(np.nanmax(model_Z_base2))
    print(np.nanmax(radar_Z_out))
#########################################################################

   
#if D1_3DVAR :
if __name__ == "__main__" :

    # Observed reflectivity is stored divided by Rvarname[0]. Bins below the
    # PMR threshold, which includes the MISSING fill, become NaN.
    obs_mesh_meso[obs_mesh_meso < threshold_PMR/Rvarname[0] ]= np.nan
    obs_mesh_meso[obs_mesh_meso< 0. ]= 0
    # Negative simulated values, including the MISSING fill, become 0.
    model_Value_base2[model_Value_base2<0.]=0.

########################################################################################
    def worker(args):
        """Estimate the background profile for one output column (jj, ii).

        Steps:

        1. Interpolate every candidate background column in model_Value_base2
           to the observed bin heights of the (2*shift_max+1)^2 neighbourhood.
        2. Score each candidate against the observations: squared Euclidean
           distance divided by the number of valid observed bins.
        3. Take the kargindex closest candidates.
        4. Average their profiles from profile_base0 with weights
           exp(-0.5*(d - d_min)).

        Parameters
        ----------
        args : tuple(int, int)
            (jj, ii) indices into obs_mesh_meso / radar_Z_out. They must lie at
            least shift_max from the grid edges for the neighbourhood slices.

        Returns
        -------
        ndarray
            Flattened profile of length profile_base0.shape[0]*shape[-1].
            All MISSING when the column has no usable observation.

        Raises
        ------
        RuntimeError
            If all interpolated background values are invalid, or all
            distances are NaN.
        """
        jj,ii = args
        n_out=profile_base0.shape[0]*profile_base0.shape[-1]
        tmpsr=obs_mesh_meso[:,jj,ii]

        # No observation in this column.
        if np.isnan(tmpsr).all() :
            return np.full(shape=n_out,fill_value=MISSING)
        # Need more than 10 observed bins above dbz_threshold.
        if np.count_nonzero(tmpsr > dbz_threshold/Rvarname[0]) <= 10 :
            return np.full(shape=n_out,fill_value=MISSING)

        if debug :
            print("calculate  %i and %i" % (ii,jj))
            print(np.nanmax(obs_mesh_meso[:,jj,ii]))

        obs_3D=np.array(obs_mesh_meso[:,jj-shift_max:jj+shift_max+1,ii-shift_max:ii+shift_max+1])
        high_3D=radar_Z_out[:,jj-shift_max:jj+shift_max+1,ii-shift_max:ii+shift_max+1]

        # Interpolate all shifted candidate columns to the observed heights.
        base_1D_3D=interp_radar_shift(model_value_in=model_Value_base2.T, model_z_in=model_Z_base2.T,
                radar_z_out=high_3D.T,linlog=1,missing=MISSING,debug_flag=False).T

        # One row per observed bin; base has one column per candidate.
        obs_1D=obs_3D.reshape([-1,1])
        base_1D_1D=np.array(base_1D_3D.reshape([-1,base_1D_3D.shape[-1]]))

        if debug :
            print("obs_1D.shape")
            print(obs_1D.shape)
            print("base_1D_1D.shape")
            print(base_1D_1D.shape)
            print("base_1D_3D.npmax")
            print(np.nanmax(base_1D_3D))

        # Interpolated values not above `threshold` are invalid.
        cond=(base_1D_1D > threshold)
        base_1D_1D=np.where(cond,base_1D_1D,np.nan)

        # These checks used to call exit(). Inside a pool worker that raises
        # SystemExit, which kills the worker without returning a result and
        # leaves pool.map() waiting forever. A regular exception reaches the
        # parent process instead.
        if np.isnan(base_1D_1D).all() :
            raise RuntimeError("bases are all nan xxx (jj=%i, ii=%i)" % (jj,ii))

        obs_valid=~np.isnan(obs_1D).any(axis=1)
        if np.isnan(base_1D_1D[obs_valid,:]).all() :
            raise RuntimeError("bases are all nan (jj=%i, ii=%i)" % (jj,ii))

        # NOTE(review): lenave counts valid observations *before* bins without
        # any valid background are dropped below, so it can exceed the number
        # of bins actually entering cdist -- confirm this normalisation.
        lenave=np.count_nonzero(obs_valid)
        if debug :
            print("lenave")
            print(lenave)
            print(obs_1D[obs_valid,:])
        # Drop observed bins with no valid background in any candidate. Then
        # set the remaining invalid background values to -10, because cdist
        # cannot handle NaN.
        obs_1D[np.isnan(base_1D_1D).all(axis=1),:]=np.nan
        base_1D_1D[np.isnan(base_1D_1D)]=-10.

        used=~np.isnan(obs_1D).any(axis=1)
        dist_dBZ=cdist(obs_1D[used,:].T,
                       base_1D_1D[used,:].T,metric='sqeuclidean')/lenave

        if (np.isnan(dist_dBZ).all()) :
            print(obs_1D)
            print(obs_1D[used])
            print("base exit()")
            print(base_1D_1D[used,0:2])
            raise RuntimeError("distances are all nan (jj=%i, ii=%i)" % (jj,ii))

        dist_dBZ_1=dist_dBZ.reshape([-1])

        # kargindex candidates with the smallest distance (unordered).
        index_array=np.argpartition(dist_dBZ_1,kargindex)[0:kargindex]
        if debug1 :
            print("dist_dBZ_1")
            print(dist_dBZ_1[index_array])

        # Weights exp(-0.5*(d - d_min)). Subtracting the minimum makes the
        # largest weight exactly 1, so the weights cannot all underflow to 0.
        dist_dBZ_use=dist_dBZ_1[index_array]
        dist_min=np.nanmin(dist_dBZ_use)
        dist_dBZ_use-=dist_min

        dist_dBZ_W=np.exp(-0.5*dist_dBZ_use)
        if debug1 :
            print("dist_dBZ_use")
            print(dist_dBZ_use)
            print("dist_dBZ_W")
            print(dist_dBZ_W)

        # Weighted average of the selected background profiles, ignoring NaNs.
        masked_data=np.ma.masked_array(profile_base0[:,index_array,:],np.isnan(profile_base0[:,index_array,:]))
        weights=np.tile(dist_dBZ_W.reshape([1,-1,1]),(profile_base0.shape[0],1,profile_base0.shape[-1]))
        profile_1D=np.ma.average(masked_data,axis=1,weights=weights)

        if debug1 :
            print(profile_1D)

        # Entries where every selected profile is NaN come back masked. Fill
        # them with MISSING, as the no-retrieval paths do, instead of passing
        # the undefined data under the mask on to the output file.
        return np.ma.filled(profile_1D,MISSING).reshape(-1)

########################################################################################

    parallel=True
    list_sweep=[]

    print("start parallel:")

    # Columns per output row. Defined unconditionally: it used to be set only
    # when parallel was True, so the serial path failed with a NameError.
    npixel= we_out
    pool = mp.Pool(processes=mp.cpu_count()) if parallel else None

    try:
        # One output row (fixed jj) per iteration, all ii columns at once.
        for jj in tqdm.tqdm(range(jshift_start,jshift_end)) :
            jj_1scan = np.repeat(jj, npixel)
            ii_1scan = np.arange(ishift_start,ishift_end)
            params_generator = zip(jj_1scan,ii_1scan)
            if parallel:
                # [1]. Parallel (by multiprocessing)
                ppi_sweep = pool.map(worker, params_generator)
            else:
                # [2]. Serial. This branch used to be commented out, leaving
                # ppi_sweep undefined.
                ppi_sweep = map(worker, params_generator)
            list_sweep.append(list(ppi_sweep))
    except BaseException:
        # On a worker error (or Ctrl-C), do not wait for the tasks still
        # queued: stop the workers now and re-raise.
        if pool is not None:
            pool.terminate()
            pool.join()
        raise

    if pool is not None:
        pool.close()
        pool.join()

    # list_sweep is (sn_out, we_out, nz*nvar). Reshape, then move the two grid
    # axes last -> (nz, nvar, sn_out, we_out).
    D1_out_data=np.moveaxis(np.array(list_sweep).reshape([sn_out,we_out,profile_base.shape[0],profile_base.shape[-1]]), [0,1],[-2,-1] )

    da = xr.DataArray(
        data=D1_out_data,
        dims=["bt", "var", "sn","we"],
        coords=dict(
            bt=np.arange(D1_out_data.shape[0]),
            var=np.arange(D1_out_data.shape[1]),
            sn=np.arange(D1_out_data.shape[2]),
            we=np.arange(D1_out_data.shape[3])
        ),
        attrs=dict(
            description="D1_out",
            units="None",
        ),
    )
    ds=da.to_dataset(name = 'D1_out')
    ds.to_netcdf("tmp.nc")

